agentmux_srv\backend\rpc/
engine.rs

1// Copyright 2025-2026, AgentMux Corp.
2// SPDX-License-Identifier: Apache-2.0
3
4//! RPC engine: handles incoming RPC requests, dispatches to handlers,
5//! and manages request/response lifecycle with timeouts and streaming.
6//! Port of Go's pkg/wshutil/wshrpc.go (WshRpc struct + handler dispatch).
7
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Mutex};
14
15use tokio::sync::mpsc;
16use uuid::Uuid;
17
18use super::super::rpc_types::{RpcContext, RpcMessage, RpcOpts, COMMAND_EVENT_RECV};
19
// ---- Constants (mirroring the Go implementation) ----

/// Timeout in milliseconds used whenever a request does not carry a positive timeout.
pub const DEFAULT_TIMEOUT_MS: i64 = 5_000;
/// Capacity of the bounded per-request channel that buffers incoming responses.
const RESP_CH_SIZE: usize = 32;
24
25// ---- Handler types ----
26
27/// Result type for RPC handler responses.
28pub type HandlerResult = Result<Option<serde_json::Value>, String>;
29
30/// A boxed async handler function.
31/// Takes the command data and returns either:
32/// - Ok(Some(value)) for a single response
33/// - Ok(None) for no response
34/// - Err(msg) for an error response
35pub type CommandHandler = Box<
36    dyn Fn(serde_json::Value, RpcContext) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>>
37        + Send
38        + Sync,
39>;
40
41/// A streaming handler that returns a channel of responses.
42pub type StreamHandler = Box<
43    dyn Fn(
44            serde_json::Value,
45            RpcContext,
46        )
47            -> Pin<Box<dyn Future<Output = Result<mpsc::Receiver<HandlerResult>, String>> + Send>>
48        + Send
49        + Sync,
50>;
51
52enum Handler {
53    Call(CommandHandler),
54    #[allow(dead_code)]
55    Stream(StreamHandler),
56}
57
58// ---- RPC Response Handler ----
59
60/// Allows an RPC handler to send responses back to the caller.
61/// Matches Go's `RpcResponseHandler`.
62pub struct RpcResponseHandler {
63    engine: Arc<WshRpcEngine>,
64    req_id: String,
65    #[allow(dead_code)]
66    source: String,
67    canceled: AtomicBool,
68    done: AtomicBool,
69}
70
71impl RpcResponseHandler {
72    /// Send a single response (or streaming chunk).
73    /// Set `done` to true for the final response.
74    pub fn send_response(&self, data: Option<serde_json::Value>, done: bool) {
75        if self.done.load(Ordering::Relaxed) {
76            return;
77        }
78        let msg = RpcMessage {
79            resid: self.req_id.clone(),
80            data,
81            cont: !done,
82            ..Default::default()
83        };
84        if done {
85            self.done.store(true, Ordering::Relaxed);
86        }
87        self.engine.send_output(msg);
88    }
89
90    /// Send an error response.
91    pub fn send_error(&self, err: &str) {
92        if self.done.load(Ordering::Relaxed) {
93            return;
94        }
95        self.done.store(true, Ordering::Relaxed);
96        let msg = RpcMessage {
97            resid: self.req_id.clone(),
98            error: err.to_string(),
99            ..Default::default()
100        };
101        self.engine.send_output(msg);
102    }
103
104    /// Check if the request has been canceled.
105    #[allow(dead_code)]
106    pub fn is_canceled(&self) -> bool {
107        self.canceled.load(Ordering::Relaxed)
108    }
109
110    /// Get the source route ID of the request.
111    #[allow(dead_code)]
112    pub fn get_source(&self) -> &str {
113        &self.source
114    }
115
116    /// Mark this handler as canceled.
117    fn cancel(&self) {
118        self.canceled.store(true, Ordering::Relaxed);
119    }
120
121    /// Finalize: send empty done response if not already done.
122    fn finalize(&self) {
123        if self.done.load(Ordering::Relaxed) {
124            return;
125        }
126        self.send_response(None, true);
127    }
128}
129
130// ---- RPC Request Handler (client-side) ----
131
132/// Tracks an outgoing request and collects responses.
133/// Matches Go's `RpcRequestHandler`.
134#[allow(dead_code)]
135pub struct RpcRequestHandler {
136    req_id: String,
137    resp_rx: mpsc::Receiver<RpcMessage>,
138    last_was_cont: bool,
139}
140
141impl RpcRequestHandler {
142    /// Get the next response. Returns None if the stream is done.
143    #[allow(dead_code)]
144    pub async fn next_response(&mut self) -> Option<Result<serde_json::Value, String>> {
145        if !self.last_was_cont && self.req_id.is_empty() {
146            return None;
147        }
148        match self.resp_rx.recv().await {
149            Some(msg) => {
150                self.last_was_cont = msg.cont;
151                if !msg.error.is_empty() {
152                    Some(Err(msg.error))
153                } else {
154                    Some(Ok(msg.data.unwrap_or(serde_json::Value::Null)))
155                }
156            }
157            None => None,
158        }
159    }
160
161    /// Check if the response stream is complete.
162    #[allow(dead_code)]
163    pub fn is_done(&self) -> bool {
164        !self.last_was_cont
165    }
166
167    /// Get the request ID.
168    #[allow(dead_code)]
169    pub fn req_id(&self) -> &str {
170        &self.req_id
171    }
172}
173
174// ---- RPC Engine ----
175
176struct EngineInner {
177    handlers: HashMap<String, Handler>,
178    pending_responses: HashMap<String, mpsc::Sender<RpcMessage>>,
179    active_handlers: HashMap<String, Arc<RpcResponseHandler>>,
180    #[allow(dead_code)]
181    auth_token: String,
182    rpc_context: Option<RpcContext>,
183}
184
185/// Core RPC engine: handles incoming RPC requests, dispatches to registered
186/// command handlers, and manages request/response lifecycle.
187///
188/// Port of Go's `WshRpc` from pkg/wshutil/wshrpc.go.
189pub struct WshRpcEngine {
190    inner: Mutex<EngineInner>,
191    output_tx: mpsc::UnboundedSender<RpcMessage>,
192}
193
194impl WshRpcEngine {
195    /// Create a new RPC engine.
196    /// Returns the engine and a receiver for outgoing messages.
197    pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<RpcMessage>) {
198        let (output_tx, output_rx) = mpsc::unbounded_channel();
199        let engine = Arc::new(Self {
200            inner: Mutex::new(EngineInner {
201                handlers: HashMap::new(),
202                pending_responses: HashMap::new(),
203                active_handlers: HashMap::new(),
204                auth_token: String::new(),
205                rpc_context: None,
206            }),
207            output_tx,
208        });
209        (engine, output_rx)
210    }
211
212    /// Register a call handler (single request → single response).
213    pub fn register_handler(&self, command: &str, handler: CommandHandler) {
214        let mut inner = self.inner.lock().unwrap();
215        inner
216            .handlers
217            .insert(command.to_string(), Handler::Call(handler));
218    }
219
220    /// Register a streaming handler (single request → stream of responses).
221    #[allow(dead_code)]
222    pub fn register_stream_handler(&self, command: &str, handler: StreamHandler) {
223        let mut inner = self.inner.lock().unwrap();
224        inner
225            .handlers
226            .insert(command.to_string(), Handler::Stream(handler));
227    }
228
229    /// Set the authentication token.
230    #[allow(dead_code)]
231    pub fn set_auth_token(&self, token: &str) {
232        let mut inner = self.inner.lock().unwrap();
233        inner.auth_token = token.to_string();
234    }
235
236    /// Get the authentication token.
237    #[allow(dead_code)]
238    pub fn get_auth_token(&self) -> String {
239        let inner = self.inner.lock().unwrap();
240        inner.auth_token.clone()
241    }
242
243    /// Set the RPC context.
244    #[allow(dead_code)]
245    pub fn set_rpc_context(&self, ctx: RpcContext) {
246        let mut inner = self.inner.lock().unwrap();
247        inner.rpc_context = Some(ctx);
248    }
249
250    /// Process an incoming message (from the transport layer).
251    pub fn handle_message(self: &Arc<Self>, msg: RpcMessage) {
252        // Cancel handling
253        if msg.cancel {
254            if !msg.reqid.is_empty() {
255                self.handle_cancel_request(&msg.reqid);
256            }
257            return;
258        }
259
260        // Event handling (special: no response)
261        if msg.command == COMMAND_EVENT_RECV {
262            // Events are handled by the event listener, not via RPC handlers
263            return;
264        }
265
266        // New command (request)
267        if !msg.command.is_empty() {
268            let engine = self.clone();
269            tokio::spawn(async move {
270                engine.handle_request(msg).await;
271            });
272            return;
273        }
274
275        // Response (has resid)
276        if !msg.resid.is_empty() {
277            self.handle_response(msg);
278        }
279    }
280
281    /// Send an RPC command and wait for a single response.
282    #[allow(dead_code)]
283    pub async fn send_command(
284        self: &Arc<Self>,
285        command: &str,
286        data: serde_json::Value,
287        opts: &RpcOpts,
288    ) -> Result<serde_json::Value, String> {
289        let mut handler = self.send_request(command, data, opts)?;
290        match handler.next_response().await {
291            Some(result) => result,
292            None => Err("no response received".to_string()),
293        }
294    }
295
296    /// Send an RPC command and get a request handler for streaming responses.
297    #[allow(dead_code)]
298    pub fn send_request(
299        self: &Arc<Self>,
300        command: &str,
301        data: serde_json::Value,
302        opts: &RpcOpts,
303    ) -> Result<RpcRequestHandler, String> {
304        let req_id = Uuid::new_v4().to_string();
305        let (resp_tx, resp_rx) = mpsc::channel(RESP_CH_SIZE);
306
307        {
308            let mut inner = self.inner.lock().unwrap();
309            inner
310                .pending_responses
311                .insert(req_id.clone(), resp_tx);
312        }
313
314        let timeout = if opts.timeout > 0 {
315            opts.timeout
316        } else {
317            DEFAULT_TIMEOUT_MS
318        };
319        let route = if opts.route.is_empty() {
320            String::new()
321        } else {
322            opts.route.clone()
323        };
324
325        let msg = RpcMessage {
326            command: command.to_string(),
327            reqid: req_id.clone(),
328            timeout,
329            route,
330            data: Some(data),
331            authtoken: self.get_auth_token(),
332            ..Default::default()
333        };
334        self.send_output(msg);
335
336        Ok(RpcRequestHandler {
337            req_id,
338            resp_rx,
339            last_was_cont: true, // assume more data initially
340        })
341    }
342
343    /// Send a fire-and-forget command (no response expected).
344    #[allow(dead_code)]
345    pub fn send_command_no_response(
346        &self,
347        command: &str,
348        data: serde_json::Value,
349        route: &str,
350    ) {
351        let msg = RpcMessage {
352            command: command.to_string(),
353            data: Some(data),
354            route: route.to_string(),
355            authtoken: self.get_auth_token(),
356            ..Default::default()
357        };
358        self.send_output(msg);
359    }
360
361    // ---- Internal ----
362
363    fn send_output(&self, msg: RpcMessage) {
364        let _ = self.output_tx.send(msg);
365    }
366
367    async fn handle_request(self: Arc<Self>, msg: RpcMessage) {
368        let request_start = std::time::Instant::now();
369        let timeout_ms = if msg.timeout > 0 {
370            msg.timeout
371        } else {
372            DEFAULT_TIMEOUT_MS
373        };
374
375        let handler = Arc::new(RpcResponseHandler {
376            engine: self.clone(),
377            req_id: msg.reqid.clone(),
378            source: msg.source.clone(),
379            canceled: AtomicBool::new(false),
380            done: AtomicBool::new(false),
381        });
382
383        // Register the active handler
384        if !msg.reqid.is_empty() {
385            let mut inner = self.inner.lock().unwrap();
386            inner
387                .active_handlers
388                .insert(msg.reqid.clone(), handler.clone());
389        }
390
391        let rpc_context = {
392            let inner = self.inner.lock().unwrap();
393            inner.rpc_context.clone().unwrap_or_default()
394        };
395
396        let data = msg.data.unwrap_or(serde_json::Value::Null);
397        let command = msg.command.clone();
398
399        // Look up handler
400        let has_call;
401        let has_stream;
402        {
403            let inner = self.inner.lock().unwrap();
404            match inner.handlers.get(&command) {
405                Some(Handler::Call(_)) => {
406                    has_call = true;
407                    has_stream = false;
408                }
409                Some(Handler::Stream(_)) => {
410                    has_call = false;
411                    has_stream = true;
412                }
413                None => {
414                    has_call = false;
415                    has_stream = false;
416                }
417            }
418        }
419
420        let dispatch_elapsed = request_start.elapsed();
421
422        if !has_call && !has_stream {
423            handler.send_error(&format!("unknown command: {}", command));
424            self.cleanup_handler(&msg.reqid);
425            return;
426        }
427
428        let timeout_dur = std::time::Duration::from_millis(timeout_ms as u64);
429
430        if has_call {
431            // Call handler: single response with timeout.
432            // Create the future while holding the lock, then drop the lock before awaiting.
433            let handler_start = std::time::Instant::now();
434            let fut = {
435                let inner = self.inner.lock().unwrap();
436                match inner.handlers.get(&command) {
437                    Some(Handler::Call(h)) => h(data.clone(), rpc_context.clone()),
438                    _ => Box::pin(async { Err("handler disappeared".to_string()) }),
439                }
440            };
441            let result = tokio::time::timeout(timeout_dur, fut).await;
442            let handler_elapsed = handler_start.elapsed();
443            let total_elapsed = request_start.elapsed();
444
445            tracing::info!(
446                "[rpc-perf] command={} dispatch={:.2}ms handler={:.2}ms total={:.2}ms",
447                command,
448                dispatch_elapsed.as_secs_f64() * 1000.0,
449                handler_elapsed.as_secs_f64() * 1000.0,
450                total_elapsed.as_secs_f64() * 1000.0,
451            );
452
453            match result {
454                Ok(Ok(resp_data)) => handler.send_response(resp_data, true),
455                Ok(Err(err)) => handler.send_error(&err),
456                Err(_) => handler.send_error(&format!("EC-TIME: timeout ({}ms)", timeout_ms)),
457            }
458        } else {
459            // Stream handler: same pattern — build future under lock, await outside.
460            let fut = {
461                let inner = self.inner.lock().unwrap();
462                match inner.handlers.get(&command) {
463                    Some(Handler::Stream(h)) => h(data.clone(), rpc_context.clone()),
464                    _ => Box::pin(async { Err("handler disappeared".to_string()) }),
465                }
466            };
467            let stream_result = tokio::time::timeout(timeout_dur, fut).await;
468
469            match stream_result {
470                Ok(Ok(mut rx)) => {
471                    // Read streaming responses
472                    loop {
473                        match tokio::time::timeout(timeout_dur, rx.recv()).await {
474                            Ok(Some(Ok(resp_data))) => {
475                                handler.send_response(resp_data, false);
476                            }
477                            Ok(Some(Err(err))) => {
478                                handler.send_error(&err);
479                                break;
480                            }
481                            Ok(None) => {
482                                // Channel closed — stream done
483                                handler.finalize();
484                                break;
485                            }
486                            Err(_) => {
487                                handler.send_error(&format!(
488                                    "EC-TIME: stream timeout ({}ms)",
489                                    timeout_ms
490                                ));
491                                break;
492                            }
493                        }
494                    }
495                }
496                Ok(Err(err)) => handler.send_error(&err),
497                Err(_) => {
498                    handler.send_error(&format!("EC-TIME: timeout ({}ms)", timeout_ms))
499                }
500            }
501        }
502
503        self.cleanup_handler(&msg.reqid);
504    }
505
506    fn handle_response(&self, msg: RpcMessage) {
507        let inner = self.inner.lock().unwrap();
508        if let Some(tx) = inner.pending_responses.get(&msg.resid) {
509            let is_done = !msg.cont;
510            let _ = tx.try_send(msg.clone());
511            if is_done {
512                drop(inner);
513                let mut inner = self.inner.lock().unwrap();
514                inner.pending_responses.remove(&msg.resid);
515            }
516        }
517    }
518
519    fn handle_cancel_request(&self, req_id: &str) {
520        let inner = self.inner.lock().unwrap();
521        if let Some(handler) = inner.active_handlers.get(req_id) {
522            handler.cancel();
523        }
524    }
525
526    fn cleanup_handler(&self, req_id: &str) {
527        if req_id.is_empty() {
528            return;
529        }
530        let mut inner = self.inner.lock().unwrap();
531        inner.active_handlers.remove(req_id);
532    }
533}
534
// ====================================================================
// Tests
// ====================================================================

#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    /// Wait up to `ms` milliseconds for the engine's next outgoing message.
    async fn next_output(rx: &mut mpsc::UnboundedReceiver<RpcMessage>, ms: u64) -> RpcMessage {
        tokio::time::timeout(Duration::from_millis(ms), rx.recv())
            .await
            .expect("timed out waiting for an outgoing message")
            .expect("engine output channel closed")
    }

    #[tokio::test]
    async fn test_register_and_call_handler() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.register_handler(
            "echo",
            Box::new(|data, _ctx| {
                Box::pin(async move { Ok(Some(data)) })
            }),
        );

        engine.handle_message(RpcMessage {
            command: "echo".to_string(),
            reqid: "req-1".to_string(),
            data: Some(serde_json::json!({"hello": "world"})),
            ..Default::default()
        });

        let resp = next_output(&mut output_rx, 1000).await;
        assert_eq!(resp.resid, "req-1");
        assert!(!resp.cont);
        assert_eq!(resp.data, Some(serde_json::json!({"hello": "world"})));
    }

    #[tokio::test]
    async fn test_unknown_command_returns_error() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.handle_message(RpcMessage {
            command: "nonexistent".to_string(),
            reqid: "req-2".to_string(),
            ..Default::default()
        });

        let resp = next_output(&mut output_rx, 1000).await;
        assert_eq!(resp.resid, "req-2");
        assert!(resp.error.contains("unknown command"));
    }

    #[tokio::test]
    async fn test_handler_error_returns_error_response() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.register_handler(
            "failme",
            Box::new(|_data, _ctx| {
                Box::pin(async move { Err("something went wrong".to_string()) })
            }),
        );

        engine.handle_message(RpcMessage {
            command: "failme".to_string(),
            reqid: "req-3".to_string(),
            ..Default::default()
        });

        let resp = next_output(&mut output_rx, 1000).await;
        assert_eq!(resp.resid, "req-3");
        assert_eq!(resp.error, "something went wrong");
    }

    #[tokio::test]
    async fn test_send_command_roundtrip() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        // A fake peer that answers the outgoing request by echoing its data.
        let engine_clone = engine.clone();
        tokio::spawn(async move {
            if let Some(msg) = output_rx.recv().await {
                engine_clone.handle_message(RpcMessage {
                    resid: msg.reqid.clone(),
                    data: msg.data.clone(),
                    ..Default::default()
                });
            }
        });

        let opts = RpcOpts {
            timeout: 1000,
            ..Default::default()
        };
        let result = engine
            .send_command("test", serde_json::json!(42), &opts)
            .await;

        assert_eq!(result, Ok(serde_json::json!(42)));
    }

    #[tokio::test]
    async fn test_request_handler_reports_done_after_final_response() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        let mut handler = engine
            .send_request("stream", serde_json::Value::Null, &RpcOpts::default())
            .expect("send_request should succeed while the output channel is open");
        let req = next_output(&mut output_rx, 1000).await;
        assert_eq!(req.reqid, handler.req_id());

        engine.handle_message(RpcMessage {
            resid: req.reqid.clone(),
            data: Some(serde_json::json!(1)),
            cont: true,
            ..Default::default()
        });
        engine.handle_message(RpcMessage {
            resid: req.reqid.clone(),
            data: Some(serde_json::json!(2)),
            ..Default::default()
        });

        assert_eq!(handler.next_response().await, Some(Ok(serde_json::json!(1))));
        assert!(!handler.is_done());
        assert_eq!(handler.next_response().await, Some(Ok(serde_json::json!(2))));
        assert!(handler.is_done());
        assert_eq!(handler.next_response().await, None);
    }

    #[tokio::test]
    async fn test_response_without_pending_request_is_ignored() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.handle_message(RpcMessage {
            resid: "nobody-asked".to_string(),
            data: Some(serde_json::json!(1)),
            ..Default::default()
        });

        assert!(output_rx.try_recv().is_err());
        assert!(engine.inner.lock().unwrap().pending_responses.is_empty());
    }

    #[tokio::test]
    async fn test_stream_handler() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.register_stream_handler(
            "counter",
            Box::new(|_data, _ctx| {
                Box::pin(async move {
                    let (tx, rx) = mpsc::channel(8);
                    tokio::spawn(async move {
                        for i in 0..3 {
                            let _ = tx.send(Ok(Some(serde_json::json!(i)))).await;
                        }
                        // Dropping `tx` closes the channel and ends the stream.
                    });
                    Ok(rx)
                })
            }),
        );

        engine.handle_message(RpcMessage {
            command: "counter".to_string(),
            reqid: "req-stream".to_string(),
            ..Default::default()
        });

        // Three data chunks followed by the empty terminating response.
        let mut responses = Vec::with_capacity(4);
        for _ in 0..4 {
            responses.push(next_output(&mut output_rx, 2000).await);
        }

        for (expected, resp) in (0..3).zip(&responses[..3]) {
            assert_eq!(resp.resid, "req-stream");
            assert!(resp.cont);
            assert_eq!(resp.data, Some(serde_json::json!(expected)));
        }
        let last = &responses[3];
        assert_eq!(last.resid, "req-stream");
        assert!(!last.cont);
        assert!(last.data.is_none());
        assert!(last.error.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_request() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.register_handler(
            "slow",
            Box::new(|_data, _ctx| {
                Box::pin(async move {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    Ok(Some(serde_json::json!("done")))
                })
            }),
        );

        engine.handle_message(RpcMessage {
            command: "slow".to_string(),
            reqid: "req-cancel".to_string(),
            timeout: 300,
            ..Default::default()
        });

        // Give the spawned request task time to start and register itself.
        tokio::time::sleep(Duration::from_millis(50)).await;
        engine.handle_message(RpcMessage {
            cancel: true,
            reqid: "req-cancel".to_string(),
            ..Default::default()
        });

        let canceled = engine
            .inner
            .lock()
            .unwrap()
            .active_handlers
            .get("req-cancel")
            .map(|h| h.is_canceled());
        assert_eq!(canceled, Some(true));

        // Call handlers are not aborted by cancellation, so the request still ends with
        // the engine's timeout error, after which it is no longer tracked.
        let resp = next_output(&mut output_rx, 2000).await;
        assert_eq!(resp.resid, "req-cancel");
        assert!(resp.error.starts_with("EC-TIME"), "unexpected error: {}", resp.error);
        assert!(!engine
            .inner
            .lock()
            .unwrap()
            .active_handlers
            .contains_key("req-cancel"));
    }

    #[tokio::test]
    async fn test_send_command_no_response() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.send_command_no_response("notify", serde_json::json!({"msg": "hi"}), "");

        let msg = next_output(&mut output_rx, 100).await;
        assert_eq!(msg.command, "notify");
        assert!(msg.reqid.is_empty());
    }

    #[tokio::test]
    async fn test_auth_token() {
        let (engine, _output_rx) = WshRpcEngine::new();
        assert!(engine.get_auth_token().is_empty());

        engine.set_auth_token("my-secret-token");
        assert_eq!(engine.get_auth_token(), "my-secret-token");
    }

    #[tokio::test]
    async fn test_rpc_context() {
        let (engine, mut output_rx) = WshRpcEngine::new();

        engine.set_rpc_context(RpcContext {
            client_type: "connserver".to_string(),
            blockid: "blk-1".to_string(),
            ..Default::default()
        });

        // The handler echoes the context it was invoked with.
        engine.register_handler(
            "checkctx",
            Box::new(|_data, ctx| {
                Box::pin(async move {
                    Ok(Some(serde_json::json!({
                        "ctype": ctx.client_type,
                        "blockid": ctx.blockid,
                    })))
                })
            }),
        );

        engine.handle_message(RpcMessage {
            command: "checkctx".to_string(),
            reqid: "req-ctx".to_string(),
            ..Default::default()
        });

        let resp = next_output(&mut output_rx, 1000).await;
        assert_eq!(resp.resid, "req-ctx");
        assert_eq!(
            resp.data,
            Some(serde_json::json!({"ctype": "connserver", "blockid": "blk-1"}))
        );
    }
}